import torch
import torch.nn as nn
import numpy as np
import random

from data_loader_imagenet import *
from torchvision.models import resnet18, resnet50
from scipy.linalg import orth

from imagenet_delete_add_hierachy import file_delete, file_add
import sys
import csv
import argparse

# Module-level holder for the most recent output captured by the forward hook
# `get_penultimate_fts` below. It is a global variable so that the hook can overwrite it
# on every forward pass and the feature-extraction code can read it afterwards.
penultimate_fts = None


def get_penultimate_fts(self, input, output):
    """
    Forward hook that stores a module's output in the global ``penultimate_fts``.

    The signature follows ``register_forward_hook``: ``(module, input, output)``.
    Only ``output`` is used. The hook returns ``None``, so the module's output is
    passed on unchanged.
    """
    global penultimate_fts
    penultimate_fts = output


def load_model_kd(alpha, temperature, model_name=None):
    """
    Build a ResNet and load the KD checkpoint trained with the given settings.

    The checkpoint is read from
    ``./output/checkpoints/{model_name}-t=resnet50-a={alpha}-T={temperature}.pth.tar``.

    :param alpha: KD weight exactly as it appears in the checkpoint filename (e.g. '0.1')
    :param temperature: KD temperature exactly as it appears in the filename (e.g. '3')
    :param model_name: 'resnet18' builds a ResNet-18, anything else a ResNet-50.
        Defaults to the module-level ``args.model`` that the script sets in ``__main__``.
    :return: (model loaded with trained weights, cleaned state dict, checkpoint path)
    """
    if model_name is None:
        # Backward-compatible fallback: the script stores the model name on the global args.
        model_name = args.model

    model = resnet18() if model_name == 'resnet18' else resnet50()
    path = f'./output/checkpoints/{model_name}-t=resnet50-a={alpha}-T={temperature}.pth.tar'

    # map_location='cpu' loads the tensors onto the CPU regardless of the device they were
    # saved from. load_state_dict copies them into the (CPU) model parameters anyway; the
    # caller moves the model to the GPU afterwards.
    ckpt = torch.load(path, map_location='cpu')['state_dict']

    # Keep only keys carrying the 'module.' prefix (presumably added by an
    # nn.DataParallel/DDP wrapper at save time) and strip that prefix exactly once.
    # str.replace would also rewrite a 'module.' occurring later inside a key.
    prefix = 'module.'
    ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    model.load_state_dict(ckpt)

    return model, ckpt, path


def visualize(model, dataloader, num_sample=100, category=None):
    """
    Extract penultimate-layer features for a fixed subset of samples per class.

    Features are the output of ``model.avgpool``. They are captured by the module-level
    ``get_penultimate_fts`` forward hook into the global ``penultimate_fts``.

    :param model: network with an ``avgpool`` sub-module (e.g. a torchvision ResNet)
    :param dataloader: iterable yielding ``(x, y)`` batches of input tensors and label tensors
    :param num_sample: number of samples kept per class. The first ``num_sample`` samples
        of each class in dataloader order are used, so an unshuffled loader selects the
        same samples for every model being compared.
    :param category: iterable of class labels to keep, in output order.
        ``None`` keeps every label present (sorted).
    :return: ``(features, targets)``. ``features`` has shape
        ``(len(category) * num_sample, D)`` and ``targets`` has shape
        ``(len(category) * num_sample, 1)``; rows are grouped class by class.
    :raises ValueError: if the dataloader yields no batches or a class has fewer than
        ``num_sample`` samples
    """
    # Use the GPU when available; fall back to the CPU otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device).eval()

    feature_array = []
    label_array = []
    # Register the hook ONCE for the whole pass. Registering it inside the batch loop
    # stacked one more hook per batch and left all of them attached to the model.
    handle = model.avgpool.register_forward_hook(get_penultimate_fts)
    try:
        with torch.no_grad():
            for x, y in dataloader:
                model(x.to(device))
                assert torch.is_tensor(penultimate_fts)
                # flatten(1) keeps the batch axis even for a single-sample batch;
                # squeeze() would drop it and break the concatenation below.
                feature_array.append(penultimate_fts.detach().flatten(1).cpu().numpy())
                label_array.append(y.cpu().numpy().reshape(-1, 1))
    finally:
        handle.remove()

    if not feature_array:
        raise ValueError('dataloader yielded no batches')
    output_array = np.concatenate(feature_array, axis=0)
    target_array = np.concatenate(label_array, axis=0)
    if category is None:
        category = np.unique(target_array)

    output_subset = []
    target_subset = []
    for cls in category:
        cls_index = np.where(target_array[:, 0] == cls)[0]
        if cls_index.size < num_sample:
            raise ValueError(f'class {cls} has only {cls_index.size} samples, '
                             f'{num_sample} requested')
        # Always the first num_sample samples of the class, so every model sees the same ones.
        sample_index = cls_index[:num_sample]
        output_subset.append(output_array[sample_index])
        target_subset.append(target_array[sample_index])
    output_subset_concat = np.concatenate(output_subset, axis=0)
    target_subset_concat = np.concatenate(target_subset, axis=0)
    print('Feature Shape :', output_subset_concat.shape)
    print('Target Shape :', target_subset_concat.shape)

    return output_subset_concat, target_subset_concat


def relative_distance(features, path, num_class, num_sample=100,
                      distance_type="l2", title=None, print_table=None):
    """
    Compute the row-normalised distance matrix between class centroids.

    ``features`` must be grouped class by class: rows
    ``[k * num_sample, (k + 1) * num_sample)`` belong to class ``k``. Entry ``(j, k)``
    of the result is the distance between the centroids of classes ``j`` and ``k``,
    divided by the sum of row ``j``.

    :param features: array-like of shape ``(N, D)`` with ``N >= num_class * num_sample``
    :param path: checkpoint path, used only to derive the CSV filename
        (``path[:-4]`` + ``_eta_{distance_type}_distance.csv``)
    :param num_class: number of classes in total
    :param num_sample: number of samples per class used for the centroids
    :param distance_type: ``"l2"`` (Euclidean, default) or ``"l1"`` (sum of absolute differences)
    :param title: class names for the CSV header and row labels. The list is not modified.
        Defaults to the class indices when ``None``.
    :param print_table: if truthy, also write the matrix to the CSV file
    :return: ``(num_class, num_class)`` numpy array of relative distances
    :raises ValueError: on an unknown ``distance_type`` or too few feature rows
    """
    if distance_type not in ("l1", "l2"):
        raise ValueError(f"unknown distance_type {distance_type!r}; expected 'l1' or 'l2'")
    features = np.asarray(features)
    needed = num_class * num_sample
    if features.shape[0] < needed:
        raise ValueError(f"expected at least {needed} feature rows, got {features.shape[0]}")

    # step 1 + 2. group the features per class and take each class mean (centroid)
    centroids = features[:needed].reshape(num_class, num_sample, -1).mean(axis=1)

    # step 3. pairwise centroid distances, shape (num_class, num_class)
    diff = centroids[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    if distance_type == "l2":
        distance = np.linalg.norm(diff, axis=-1)
    else:
        distance = np.abs(diff).sum(axis=-1)

    # step 4. normalise every row by its sum. num_class is used rather than a global
    # class list, so the function also works when imported outside __main__.
    distance_relative = distance / distance.sum(axis=1, keepdims=True)

    if print_table:
        labels = list(title) if title is not None else [str(i) for i in range(num_class)]
        # path[:-4] drops a 4-character extension (for '*.pth.tar' only '.tar');
        # kept as-is so existing output filenames do not change.
        filename = f"{path[:-4]}_eta_{distance_type}_distance.csv"
        with open(filename, 'w', newline='') as csvfile:
            csvwriter = csv.writer(csvfile)
            # header: empty corner cell followed by the class names
            csvwriter.writerow([''] + labels)
            for name, row in zip(labels, distance_relative):
                csvwriter.writerow([name] + list(row))

    return distance_relative


def compute_eta(distance_array, num_class_total=25, num_class_similar=5):
    """
    Measure how much relative class distances change between two settings.

    Let ``first = distance_array[0]`` and ``second = distance_array[1]``; the caller passes
    the matrices for T=1 and T=3. For every off-diagonal pair ``(m, n)`` the relative change
    is ``(second[m][n] - first[m][n]) / first[m][n]``; diagonal entries are set to 0.
    Each of the first ``num_class_similar`` rows is then averaged, in percent, twice:
    once over the columns ``[0, num_class_similar)`` and once over the remaining columns.

    :param distance_array: pair of relative distance matrices, indexed ``[0]`` and ``[1]``
    :param num_class_total: side length of the square region of each matrix that is used
    :param num_class_similar: number of leading, semantically similar classes
    :return: numpy array of shape ``(num_class_similar, 2)``. Column 0 is the mean change
        (in %) towards the similar classes; column 1 is towards the other classes.
    """
    size = num_class_total
    similar = num_class_similar
    first = np.asarray(distance_array[0], dtype=float)[:size, :size]
    second = np.asarray(distance_array[1], dtype=float)[:size, :size]

    # Relative change off the diagonal; the diagonal (a class against itself) stays 0
    # and is never divided.
    change = np.zeros((size, size))
    np.divide(second - first, first, out=change, where=~np.eye(size, dtype=bool))

    eta = np.zeros([similar, 2])
    # NOTE(review): the zero diagonal entry lies inside columns [0, similar) and is therefore
    # included in the column-0 mean. Confirm this matches the intended definition of eta.
    eta[:, 0] = change[:similar, :similar].mean(axis=1) * 100
    eta[:, 1] = change[:similar, similar:].mean(axis=1) * 100

    return eta


def argparser(argv=None):
    """
    Parse the command-line options of this script.

    :param argv: list of argument strings; ``None`` (default) parses ``sys.argv[1:]``
    :return: argparse.Namespace with ``batch_size`` (int, default 128)
    """
    parser = argparse.ArgumentParser(description="Visualization of LS-KD features")
    # type=int: without it a value given on the command line stays a str and would be
    # passed as a string batch size to the data loader.
    parser.add_argument('--batch_size', type=int, default=128,
                        help='mini-batch size used for feature extraction (default: 128)')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # `args` must stay module-global: load_model_kd reads args.model when no model_name is given.
    args = argparser()

    # dataset settings
    category = np.arange(25)
    # After sorting by id, the first 5 classes (ids n0000000x, the retriever breeds)
    # are the semantically similar group.
    category_similar = 5
    hierarchy_path = '../data/imagenet_visualization_hierarchy'

    # original class names, in the same order as CLASSES_id below
    # NOTE(review): 'Labrado_retriever' looks like a typo of 'Labrador_retriever' (label text only).
    CLASSES_original = ['Chesapeake_Bay_retriever', 'curly-coated_retriever', 'flat-coated_retriever', 'golden_retriever', 'Labrado_retriever',
                        'speedboat', 'tub', 'minivan', 'umbrella', 'palace',
                        'magnetic_compass', 'Tibetan_mastiff', 'Arctic_fox', 'monarch', 'American_lobster',
                        'revolver', 'puck', 'orange', 'tank', 'ambulance',
                        'pay-phone', 'lemon', 'comic_book', 'Arabian_camel', 'volcano']

    # the similar classes are renamed to n0000000x so that they sort first
    CLASSES_id = ['n00000000', 'n00000001', 'n00000002', 'n00000003', 'n00000004',
                  'n04273569', 'n04493381', 'n03770679', 'n04507155', 'n03877845',
                  'n03706229', 'n02108551', 'n02120079', 'n02279972', 'n01983481',
                  'n04086273', 'n04019541', 'n07747607', 'n04389033', 'n02701002',
                  'n03902125', 'n07749582', 'n06596364', 'n02437312', 'n09472597']

    # Sort the ids (presumably matching the label order of the folder-based loader -- verify)
    # and reorder the names accordingly. A dict lookup avoids repeated list.index scans and
    # does not shadow the builtin `id`.
    id_to_name = dict(zip(CLASSES_id, CLASSES_original))
    CLASSES_id = sorted(CLASSES_id)
    CLASSES = [id_to_name[class_id] for class_id in CLASSES_id]
    print(CLASSES_id)
    print(CLASSES)

    # class names used as table headers
    fields = [CLASSES[name] for name in category]

    # --------- Step 1. get the dataloader ---------- #
    train_loader, valid_loader = get_train_valid_loader(data_dir=hierarchy_path,
                                                        batch_size=args.batch_size, augment=True,
                                                        shuffle=False)
    # ---------------------------------------------- #
    args.model = 'resnet18'  # resnet18 or resnet50
    alpha = ['0.0', '0.1']  # available checkpoints; only alpha=0.1 is evaluated below
    temperature = ['1', '3']

    # --------- Step 2. extract the features and statistics ---------- #
    # KD resnet 18/50
    row_wise_distance = []
    for visual_set in [train_loader]:
        a = '0.1'
        # identity check: the loaders are distinct objects, value equality is not meaningful
        args.num_sample = 1000 if visual_set is train_loader else 50
        for temp in temperature:
            model, state, path = load_model_kd(alpha=a, temperature=temp)

            # ------- Step 3. Feature Extraction ------- #
            # -> features (len(category) * num_sample, feature_dim), targets (..., 1)
            output_feature, output_target = visualize(model=model, dataloader=visual_set,
                                                      num_sample=args.num_sample,
                                                      category=category)

            # ------- Step 4. Get the relative distance statistics ------- #
            tmp = relative_distance(features=output_feature, path=path,
                                    num_class=len(category), distance_type='l1', num_sample=args.num_sample,
                                    title=fields.copy(), print_table=False)  # do not save csv file
            row_wise_distance.append(tmp)
    row_wise_distance = np.array(row_wise_distance)
    print(row_wise_distance.shape)  # (len(temperature), num_class, num_class)

    # Step 5. Obtain eta value
    eta = compute_eta(row_wise_distance, num_class_total=len(category), num_class_similar=category_similar)
    print(eta)


